{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "from urllib.parse import urlparse\n",
    "import gzip\n",
    "import argparse\n",
    "import mindspore.dataset as ds\n",
    "import mindspore.nn as nn\n",
    "from mindspore import context\n",
    "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
    "from mindspore.train import Model\n",
    "from mindspore.common.initializer import TruncatedNormal\n",
    "import mindspore.dataset.transforms.vision.c_transforms as CV\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "from mindspore.dataset.transforms.vision import Inter\n",
    "from mindspore.nn.metrics import Accuracy\n",
    "from mindspore.common import dtype as mstype\n",
    "from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits\n",
    "\n",
    "\n"
   ]
  },
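  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional environment check (not part of the original tutorial): the import paths above,\n",
    "# e.g. mindspore.dataset.transforms.vision.c_transforms, follow the early MindSpore 0.x\n",
    "# module layout, so printing the installed version helps confirm the notebook matches it.\n",
    "import mindspore\n",
    "print(mindspore.__version__)"
   ]
  },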
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unzipfile(gzip_path):\n",
    "    \"\"\"unzip dataset file\n",
    "    Args:\n",
    "        gzip_path: dataset file path\n",
    "    \"\"\"\n",
    "    open_file = open(gzip_path.replace('.gz',''), 'wb')\n",
    "    gz_file = gzip.GzipFile(gzip_path)\n",
    "    open_file.write(gz_file.read())\n",
    "    gz_file.close()\n",
    "\n",
    "\n",
    "def download_dataset():\n",
    "    \"\"\"Download the dataset from http://yann.lecun.com/exdb/mnist/.\"\"\"\n",
    "    print(\"******Downloading the MNIST dataset******\")\n",
    "    train_path = \"./MNIST_Data/train/\"\n",
    "    test_path = \"./MNIST_Data/test/\"\n",
    "    train_path_check = os.path.exists(train_path)\n",
    "    test_path_check = os.path.exists(test_path)\n",
    "    if train_path_check == False and test_path_check ==False:\n",
    "        os.makedirs(train_path)\n",
    "        os.makedirs(test_path)\n",
    "    train_url = {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n",
    "    test_url = {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n",
    "    for url in train_url:\n",
    "        url_parse = urlparse(url)\n",
    "        # split the file name from url\n",
    "        file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n",
    "        if not os.path.exists(file_name.replace('.gz','')):\n",
    "            file = urllib.request.urlretrieve(url, file_name)\n",
    "            unzipfile(file_name)\n",
    "            os.remove(file_name)\n",
    "    for url in test_url:\n",
    "        url_parse = urlparse(url)\n",
    "        # split the file name from url\n",
    "        file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n",
    "        if not os.path.exists(file_name.replace('.gz','')):\n",
    "            file = urllib.request.urlretrieve(url, file_name)\n",
    "            unzipfile(file_name)\n",
    "            os.remove(file_name)\n",
    "\n",
    "\n",
    "def create_dataset(data_path, batch_size=32, repeat_size=1,\n",
    "                   num_parallel_workers=1):\n",
    "    \"\"\" create dataset for train or test\n",
    "    Args:\n",
    "        data_path: Data path\n",
    "        batch_size: The number of data records in each group\n",
    "        repeat_size: The number of replicated data records\n",
    "        num_parallel_workers: The number of parallel workers\n",
    "    \"\"\"\n",
    "    # define dataset\n",
    "    mnist_ds = ds.MnistDataset(data_path)\n",
    "\n",
    "    # define operation parameters\n",
    "    resize_height, resize_width = 32, 32\n",
    "    rescale = 1.0 / 255.0\n",
    "    shift = 0.0\n",
    "    rescale_nml = 1 / 0.3081\n",
    "    shift_nml = -1 * 0.1307 / 0.3081\n",
    "\n",
    "    # define map operations\n",
    "    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Resize images to (32, 32)\n",
    "    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n",
    "    rescale_op = CV.Rescale(rescale, shift) # rescale images\n",
    "    hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n",
    "    type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n",
    "\n",
    "    # apply map operations on images\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
    "\n",
    "    # apply DatasetOps\n",
    "    buffer_size = 10000\n",
    "    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script\n",
    "    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
    "    mnist_ds = mnist_ds.repeat(repeat_size)\n",
    "\n",
    "    return mnist_ds\n",
    "\n",
    "\n",
    "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n",
    "    \"\"\"Conv layer weight initial.\"\"\"\n",
    "    weight = weight_variable()\n",
    "    return nn.Conv2d(in_channels, out_channels,\n",
    "                     kernel_size=kernel_size, stride=stride, padding=padding,\n",
    "                     weight_init=weight, has_bias=False, pad_mode=\"valid\")\n",
    "\n",
    "\n",
    "def fc_with_initialize(input_channels, out_channels):\n",
    "    \"\"\"Fc layer weight initial.\"\"\"\n",
    "    weight = weight_variable()\n",
    "    bias = weight_variable()\n",
    "    return nn.Dense(input_channels, out_channels, weight, bias)\n",
    "\n",
    "\n",
    "def weight_variable():\n",
    "    \"\"\"Weight initial.\"\"\"\n",
    "    return TruncatedNormal(0.02)\n",
    "\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"Lenet network structure.\"\"\"\n",
    "    # define the operator required\n",
    "    def __init__(self):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.batch_size = 32\n",
    "        self.conv1 = conv(1, 6, 5)\n",
    "        self.conv2 = conv(6, 16, 5)\n",
    "        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n",
    "        self.fc2 = fc_with_initialize(120, 84)\n",
    "        self.fc3 = fc_with_initialize(84, 10)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    # use the preceding operators to construct networks\n",
    "    def construct(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):\n",
    "    \"\"\"Define the training method.\"\"\"\n",
    "    print(\"============== Starting Training ==============\")\n",
    "    # load training dataset\n",
    "    ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n",
    "    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)\n",
    "\n",
    "\n",
    "def test_net(args, network, model, mnist_path):\n",
    "    \"\"\"Define the evaluation method.\"\"\"\n",
    "    print(\"============== Starting Testing ==============\")\n",
    "    # load the saved model for evaluation\n",
    "    param_dict = load_checkpoint(\"checkpoint_lenet-1_1875.ckpt\")\n",
    "    # load parameter to the network\n",
    "    load_param_into_net(network, param_dict)\n",
    "    # load testing dataset\n",
    "    ds_eval = create_dataset(os.path.join(mnist_path, \"test\"))\n",
    "    acc = model.eval(ds_eval, dataset_sink_mode=False)\n",
    "    print(\"============== Accuracy:{} ==============\".format(acc))\n",
    "\n",
    "\n"
   ]
  },
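  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not part of the original tutorial). Assuming the MNIST files are\n",
    "# already under ./MNIST_Data/train (for example after running download_dataset()), pull a\n",
    "# single batch from the pipeline above and print the shapes the LeNet5 network expects.\n",
    "demo_ds = create_dataset(\"./MNIST_Data/train\")\n",
    "for item in demo_ds.create_dict_iterator():\n",
    "    print(\"image batch shape:\", item[\"image\"].shape)  # (32, 1, 32, 32) after resize and HWC2CHW\n",
    "    print(\"label batch shape:\", item[\"label\"].shape)  # (32,), labels cast to int32\n",
    "    break"
   ]
  },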
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'args' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-ee92e83d0953>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     24\u001b[0m     \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet_opt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"Accuracy\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAccuracy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m     \u001b[0mtrain_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmnist_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrepeat_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpoint_cb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m     \u001b[0mtest_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmnist_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'args' is not defined"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    #parser = argparse.ArgumentParser(description='MindSpore LeNet Example')\n",
    "    #parser.add_argument('--device_target', type=str, default=\"CPU\", choices=['Ascend', 'GPU', 'CPU'],help='device where the code will be implemented (default: Ascend)')\n",
    "    #args = parser.parse_args()\n",
    "    context.set_context(mode=context.GRAPH_MODE, device_target=\"CPU\",enable_mem_reuse=False)\n",
    "    # download mnist dataset\n",
    "    #download_dataset()\n",
    "    # learning rate setting\n",
    "    lr = 0.01\n",
    "    momentum = 0.9\n",
    "    epoch_size = 1\n",
    "    mnist_path = \"./MNIST_Data\"\n",
    "    # define the loss function\n",
    "    net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
    "    repeat_size = epoch_size\n",
    "    # create the network\n",
    "    network = LeNet5()\n",
    "    # define the optimizer\n",
    "    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)\n",
    "    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)\n",
    "    # save the network model and parameters for subsequence fine-tuning\n",
    "    ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", config=config_ck)\n",
    "    # group layers into an object with training and evaluation features\n",
    "    model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
    "\n",
    "    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)\n",
    "    test_net(args, network, model, mnist_path)"
   ]
  },
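  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional inference sketch (not part of the original tutorial). Assuming the training cell\n",
    "# above has run and produced checkpoint_lenet-1_1875.ckpt, reload those weights and run the\n",
    "# network on one batch of test images via model.predict, comparing the argmax of the logits\n",
    "# against the ground-truth labels.\n",
    "import numpy as np\n",
    "from mindspore import Tensor\n",
    "\n",
    "param_dict = load_checkpoint(\"checkpoint_lenet-1_1875.ckpt\")\n",
    "load_param_into_net(network, param_dict)\n",
    "ds_infer = create_dataset(os.path.join(mnist_path, \"test\"))\n",
    "for item in ds_infer.create_dict_iterator():\n",
    "    images, labels = item[\"image\"], item[\"label\"]\n",
    "    output = model.predict(Tensor(images))\n",
    "    predicted = np.argmax(output.asnumpy(), axis=1)\n",
    "    print(\"predicted:\", predicted[:10])\n",
    "    print(\"labels:   \", labels[:10])\n",
    "    break"
   ]
  },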
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
